
Support skipping tracing of selected pure modules #308


Merged: 11 commits into main on Jun 18, 2025

Conversation

@tengyifei (Collaborator) commented Jun 13, 2025

This PR adds a `model.pure_modules` option to the trainer, which selects the modules to run under `@assume_pure`.

The primary benefit is that rich profiles become available during the backward pass, because `jax.vjp` preserves framework scopes there.

This PR updates the PyTorch/XLA pin to Jun 17 because it relies on pytorch/xla#9360.

Regular profile

python3 torchprime/torch_xla_models/train.py \
    model/sharding=llama-fsdp-tp ici_mesh.fsdp=8 \
    task.global_batch_size=8 model.attention_kernel=splash_attention \
    logging_steps=1 task.max_steps=15

[Screenshot: profile without assume_pure]

Profile with assume_pure

python3 torchprime/torch_xla_models/train.py \
    model/sharding=llama-fsdp-tp ici_mesh.fsdp=8 \
    task.global_batch_size=8 model.attention_kernel=splash_attention \
    logging_steps=1 task.max_steps=15 \
    model.pure_modules=[LlamaMLP,EinsumLinear]

[Screenshot: profile with assume_pure]
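As a rough illustration of what `model.pure_modules=[LlamaMLP,EinsumLinear]` could do, the sketch below walks a module tree and wraps every submodule whose class name matches, without descending into already-wrapped subtrees. All names here (`Module`, `PureModule`, `wrap_pure_modules`) are simplified stand-ins for illustration, not the actual torchprime or PyTorch/XLA implementation.

```python
class Module:
    """Minimal stand-in for torch.nn.Module with named children."""
    def __init__(self, **children):
        self._children = dict(children)

    def named_children(self):
        return self._children.items()

    def set_child(self, name, child):
        self._children[name] = child


class PureModule(Module):
    """Stand-in wrapper; the real one would run the wrapped module under @assume_pure."""
    def __init__(self, wrapped):
        super().__init__(wrapped=wrapped)
        self.wrapped = wrapped


# Illustrative module classes matching the names in the command above.
class LlamaMLP(Module): pass
class EinsumLinear(Module): pass
class Attention(Module): pass


def wrap_pure_modules(module, pure_names):
    """Recursively replace submodules whose class name is in pure_names.

    Matching modules are wrapped outermost-first: once a subtree is wrapped,
    its children are not re-wrapped.
    """
    for name, child in list(module.named_children()):
        if type(child).__name__ in pure_names:
            module.set_child(name, PureModule(child))
        else:
            wrap_pure_modules(child, pure_names)
    return module


model = Module(
    mlp=LlamaMLP(proj=EinsumLinear()),
    attn=Attention(qkv=EinsumLinear()),
)
wrap_pure_modules(model, {"LlamaMLP", "EinsumLinear"})
# model.mlp is now a PureModule wrapping the whole LlamaMLP;
# only attn.qkv (an EinsumLinear) is wrapped inside attn.
```

Wrapping at the outermost matching module means the entire subtree is traced once as a pure function, which is what lets the tracer skip re-tracing it on every step.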

@tengyifei tengyifei force-pushed the yifeit/assume-pure branch 2 times, most recently from cd77464 to d577a40 Compare June 14, 2025 02:07
tengyifei added a commit to pytorch/xla that referenced this pull request Jun 16, 2025
- Support nested tuples in `assume_pure(mark_sharding)`
- Add a `PureModule` from AI-Hypercomputer/torchprime#308
- Support `PureModule(EinsumLinear)` which uses `torch.ops.xla.einsum_linear_forward`
@tengyifei tengyifei force-pushed the yifeit/assume-pure branch from 574f142 to c8ba416 Compare June 17, 2025 08:21
@tengyifei changed the title from "Draft: Yifeit/assume pure" to "Support skipping tracing of selected pure modules" Jun 17, 2025
@tengyifei tengyifei force-pushed the yifeit/assume-pure branch 3 times, most recently from 12014b1 to ef24790 Compare June 17, 2025 22:39
@tengyifei tengyifei marked this pull request as ready for review June 17, 2025 23:53
@vlasenkoalexey (Collaborator) left a comment:

thanks for adding this functionality, this is great

@tengyifei tengyifei force-pushed the yifeit/assume-pure branch from 185cd7b to 961d168 Compare June 18, 2025 05:50
@tengyifei tengyifei merged commit 82bf6da into main Jun 18, 2025
15 checks passed
@tengyifei tengyifei deleted the yifeit/assume-pure branch June 18, 2025 06:34